# Copyright (c) 2020, Huawei Technologies.All rights reserved.
#
# Licensed under the BSD 3-Clause License  (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


def box_dtype_check(box):
    """Return ``box`` as a tensor whose dtype the NPU IoU operators accept.

    The operators wrapped in this module support ``torch.float`` and
    ``torch.half`` inputs. A tensor already in one of those dtypes is
    returned unchanged (same object, so autograd history is kept). Any
    other dtype is cast to ``torch.float``, e.g. the ``int64`` produced by
    ``torch.randint``.

    Args:
        box (Tensor): Box coordinates.

    Returns:
        Tensor: ``box`` itself, or a float32 copy of it.
    """
    # Compare the tensor's *dtype*: ``box in [torch.float, ...]`` would call
    # Tensor.__eq__ against dtype objects and never match, and the old
    # fall-through returned None for the supported dtypes.
    if box.dtype in (torch.float, torch.half):
        return box
    return box.float()


def npu_iou(boxes1,
            boxes2,
            mode="ptiou",
            is_normalized=False,
            normalized_scale=100.,
            ):
    """ Applies an NPU based IOU operation.

    Given two lists of boxes of size N and M,
    compute the IoU (intersection over union)
    between all N x M pairs of boxes.
    The box order must be (xmin, ymin, xmax, ymax).

    Compute Function:
    iou = (overlap_area + 0.001) / (union_area + 0.001)
    ptiou = overlap_area / (union_area + 0.001)

    .. note::
        This function is commonly used when matching bboxes to anchors.
        Currently, this function has no corresponding backward operator,
        so it cannot be used in IOU_Loss.

        The formula adds 0.001 to the denominator to avoid division by zero.
        When the input boxes are normalized, their areas are small, so this
        0.001 term would noticeably bias the result. In that case, scale
        the inputs up (``is_normalized=True``) to limit its influence.

    Examples::
    >>> box1 = torch.randint(0, 256, size=(32, 4))
    >>> box2 = torch.randint(0, 256, size=(16, 4))
    >>> iou1 = npu_iou(box1, box2) # (32, 16)

    Args:
        boxes1(N,4),boxes2(M,4): two `Boxes`. Contains N & M boxes, respectively. Support dtype: float, half.
        mode (String): Select the calculation mode of iou: "iou" or "ptiou". Default ptiou.
        is_normalized (Bool): Whether the value of coordinates has been normalized. Default False.
        normalized_scale (Float): Factor the normalized coordinates are multiplied by
            before the computation. Default 100.

    Returns:
        Tensor: IoU, sized [N,M].

    Raises:
        ValueError: If ``mode`` is neither "iou" nor "ptiou".
    """
    # Validate explicitly: an ``assert`` is stripped under ``python -O`` and
    # an unknown mode would then fail later with an unbound local.
    if mode not in ("iou", "ptiou"):
        raise ValueError('mode must be "iou" or "ptiou", got {!r}'.format(mode))

    boxes1 = box_dtype_check(boxes1)
    boxes2 = box_dtype_check(boxes2)

    if is_normalized:
        # Restore coordinates to a larger scale so the +0.001 smoothing term
        # does not dominate small (normalized) areas.
        boxes1 = boxes1 * normalized_scale
        boxes2 = boxes2 * normalized_scale

    iou_op = torch.npu_iou if mode == "iou" else torch.npu_ptiou
    # NOTE(review): the operator is called as (boxes2, boxes1), as in the
    # original implementation; the documented result is [N, M].
    return iou_op(boxes2, boxes1)


# Alias: ``npu_ptiou`` is the very same function object as ``npu_iou``.
# Since ``npu_iou`` defaults to mode="ptiou", calling ``npu_ptiou`` with
# default arguments computes ptiou.
npu_ptiou = npu_iou


def npu_giou(boxes1,
             boxes2,
             is_permuted=True,
             ):
    """ Applies an NPU based GIOU operation.

    Given two lists of boxes with the same shape (n, 4),
    compute the GIoU (generalized intersection over union)
    between each pair of corresponding boxes (one-to-one).
    The box format must be xywh.

    Compute Function:
    iou = overlap_area / union_area
    enclose_area = (max(x2) - min(x1)) * (max(y2) - min(y1))
    giou = iou - (enclose_area - union_area) / enclose_area

    .. note::
        This function has a corresponding backward operator,
        so it can be used in IOU_Loss.

        Currently, torch.npu_giou is only called with trans=True (xywh
        only, xyxy is not supported) and is_cross=False. With
        is_cross=False, boxes1.shape must equal boxes2.shape (one-to-one
        calculation); ((n, 4), (m, 4)) inputs are not supported. Do not
        pass other parameter values to torch.npu_giou.

    Examples::
    >>> box1 = torch.randn(32, 4)
    >>> box1.requires_grad = True
    >>> box2 = torch.randn(32, 4)
    >>> iou1 = npu_giou(box1, box2) # (32, 1)
    >>> l = iou1.sum()
    >>> l.backward()

    Args:
        boxes1 (Tensor): Predicted bboxes of format xywh, shape (n, 4).
        boxes2 (Tensor): Corresponding gt bboxes, shape (n, 4).
        is_permuted (Bool): Whether to transpose both inputs from (n, 4) to
            (4, n) before calling the NPU operator. Default True.

    Returns:
        Tensor: GIoU, sized [n, 1].

    Raises:
        ValueError: If ``boxes1`` and ``boxes2`` have different shapes.

    .. _Generalized Intersection over Union\: A Metric and A Loss for Bounding Box Regression:
        https://arxiv.org/abs/1902.09630
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``;
    # the operator is called with is_cross=False (one-to-one pairs only).
    if boxes1.shape != boxes2.shape:
        raise ValueError(
            "boxes1 and boxes2 must have the same shape, got {} and {}".format(
                tuple(boxes1.shape), tuple(boxes2.shape)))

    boxes1 = box_dtype_check(boxes1)
    boxes2 = box_dtype_check(boxes2)

    if is_permuted:
        # (n, 4) -> (4, n) layout for the NPU operator.
        boxes1 = boxes1.permute(1, 0)
        boxes2 = boxes2.permute(1, 0)

    return torch.npu_giou(boxes1, boxes2, trans=True, is_cross=False)


if __name__ == "__main__":
    # Manual smoke test; requires an NPU device (torch.npu / Tensor.npu()).
    torch.npu.set_device(0)

    def _report(result):
        """Print the shape and value range of ``result``."""
        print(result.shape, result.max(), result.min())

    # npu_iou on pixel coordinates, both modes.
    box1 = torch.FloatTensor([[10, 55, 85, 160]]).npu()
    box2 = torch.FloatTensor([[18, 45, 80, 130], [38, 85, 70, 230]]).npu()
    _report(npu_iou(box1, box2, mode="iou"))
    _report(npu_iou(box1, box2))

    # Same boxes scaled down by 100; is_normalized scales them back up.
    box1_scaled = box1 / 100.
    box2_scaled = box2 / 100.
    _report(npu_iou(box1_scaled, box2_scaled, mode="iou",
                    is_normalized=True, normalized_scale=100.))
    _report(npu_iou(box1_scaled, box2_scaled,
                    is_normalized=True, normalized_scale=100.))

    # Larger random batch: N x M pairs.
    N = 32
    M = 32 * 32
    box1 = torch.randint(0, 256, size=(N, 4)).float().npu()
    box2 = torch.randint(0, 256, size=(M, 4)).float().npu()
    _report(npu_iou(box1, box2, mode="iou"))
    _report(npu_iou(box1, box2))

    # npu_giou forward + backward on one-to-one pairs. requires_grad is set
    # on the NPU tensor itself so it is a leaf and its .grad is populated
    # (setting it on the CPU tensor before .npu() leaves the NPU copy a
    # non-leaf whose gradient is not retained).
    N = 32
    box1 = torch.randn(N, 4).npu().requires_grad_()
    box2 = torch.randn(N, 4).npu()
    giou = npu_giou(box1, box2)
    giou.sum().backward()
    _report(giou)
    print(box1.grad.shape)
